#!/usr/bin/env python

import gzip
import itertools
import os
import sys

from uta.formats.exonset import ExonSet, ExonSetReader
from uta.formats.txinfo import TxInfo, TxInfoReader


def read_exonsets(fn):
    """Return a multimap of tx_ac -> [exonset records] read from a gzip file.

    Only records whose ``tx_ac`` starts with "NM" and whose ``alt_ac`` starts
    with "NC_0000" are kept; everything else in the file is ignored.

    :param fn: path of the gzip-compressed exonset file to read
    :returns: dict mapping each transcript accession to the list of retained
        records carrying that accession (records keep their file order
        within each list, because the sort below is stable)
    """
    # Materialize the records while the file is still open, and make sure the
    # gzip handle is closed afterwards instead of leaking it.
    with gzip.open(fn) as gz_file:
        exonsets = [es for es in ExonSetReader(gz_file)
                    if es.tx_ac.startswith("NM")
                    and es.alt_ac.startswith("NC_0000")]

    # itertools.groupby only merges adjacent items, so sort by the same key
    # before grouping.
    exonsets.sort(key=lambda es: es.tx_ac)
    return {tx_ac: list(esi)
            for tx_ac, esi in itertools.groupby(exonsets, key=lambda es: es.tx_ac)}


def es_eq(es1, es2):
    """Compare two exonset records and describe how they differ.

    The fields are checked in a fixed order and each mismatch contributes a
    short label:

    * ``tx_ac``  -- ``tx_ac`` attributes differ
    * ``alt_ac`` -- ``alt_ac`` attributes differ
    * ``str``    -- ``strand`` attributes differ
    * ``nex``    -- ``len(exons_se_i)`` differs
    * ``exse``   -- ``exons_se_i`` attributes differ

    :returns: the labels of all mismatching fields joined with ";", or the
        empty string (falsy) when every checked field agrees
    """
    # (label, accessor) pairs, in the order the labels must be reported.
    probes = (
        ("tx_ac", lambda es: es.tx_ac),
        ("alt_ac", lambda es: es.alt_ac),
        ("str", lambda es: es.strand),
        ("nex", lambda es: len(es.exons_se_i)),
        ("exse", lambda es: es.exons_se_i),
    )
    return ";".join(label
                    for label, probe in probes
                    if probe(es1) != probe(es2))


def is_xy(es):
    """Return True if the record's alt_ac is NC_000023 or NC_000024.

    Only the part of ``es.alt_ac`` before the first "." is compared, so any
    suffix after the dot is ignored.
    """
    base_ac = es.alt_ac.partition(".")[0]
    return base_ac in ("NC_000023", "NC_000024")

def main(argv=None):
    """Compare the exonsets of two gzip-compressed files and print a report.

    For transcripts present in both files:

    * if either file holds more than one retained record for the transcript,
      it is reported as "degenerate" and classified by whether every record
      (from both files) passes ``is_xy``;
    * otherwise the single records are compared with ``es_eq`` and any
      mismatch labels are printed.

    :param argv: the two file paths; defaults to ``sys.argv[1:]``
    :returns: process exit status (0 on success, 2 on a usage error)
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) != 2:
        # Fail with a readable message rather than an IndexError traceback.
        sys.stderr.write("usage: {} EXONSET1.gz EXONSET2.gz\n".format(
            os.path.basename(sys.argv[0])))
        return 2

    es1 = read_exonsets(args[0])
    es2 = read_exonsets(args[1])

    es1_acs = set(es1)
    es2_acs = set(es2)

    es1_not_es2 = es1_acs - es2_acs
    es2_not_es1 = es2_acs - es1_acs
    es1_int_es2 = es1_acs & es2_acs

    print("{} in file 1, {} unique".format(len(es1), len(es1_not_es2)))
    print("{} in file 2, {} unique".format(len(es2), len(es2_not_es1)))
    print("{} in common".format(len(es1_int_es2)))

    mismatch = set()
    non_par = set()
    par = set()
    # Sorted so the report is in the same order on every run (set iteration
    # order is not stable across interpreter invocations).
    for tx_ac in sorted(es1_int_es2):
        if len(es1[tx_ac]) > 1 or len(es2[tx_ac]) > 1:
            all_xy = all(is_xy(es) for es in es1[tx_ac] + es2[tx_ac])
            print("{}: degenerate alignments; all XY={}".format(tx_ac, all_xy))
            if all_xy:
                par.add(tx_ac)
            else:
                non_par.add(tx_ac)
            continue
        failures = es_eq(es1[tx_ac][0], es2[tx_ac][0])
        if failures:
            print("{}: {}!".format(tx_ac, failures))
            mismatch.add((tx_ac, failures))

    # Summarize the collected sets; this replaces the interactive
    # IPython.embed() debugging hook that was flagged for removal.
    print("{} mismatched, {} degenerate (all XY), {} degenerate (not all XY)".format(
        len(mismatch), len(par), len(non_par)))
    return 0


if __name__ == "__main__":
    sys.exit(main())
